/*
    Copyright (c) 2013, Taiga Nomi and the respective contributors
    All rights reserved.

    Use of this source code is governed by a BSD-style license that can be found
    in the LICENSE file.
*/
#pragma once

#include <vector>

#include "tiny_dnn/util/util.h"

namespace tiny_dnn
{

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// Custom yolt_error (YOLO-style weighted squared-error loss)
/////////////////////////////////////////////////////////////////////////////////////////////////////////

// Weighted mean-squared-error: coordinate slots are scaled by w_coor (5.0),
// every 5th slot (1-based) by w_iou (0.5) — presumably the IoU/confidence
// entry of a 5-value box tuple; confirm against the network output layout.
class yolt_error
{
  public:
    // f: weighted average of squared differences between y and t.
    static float_t f(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());

        const float_t w_coor = 5.0f;
        const float_t w_iou = 0.5f;
        float_t sum{0.0};

        for (size_t idx = 0; idx < y.size(); ++idx)
        {
            const float_t diff = y[idx] - t[idx];
            const float_t sq = diff * diff;
            // 0-based index 4, 9, 14, ... corresponds to the w_iou slot
            sum += (idx % 5 == 4) ? w_iou * sq : w_coor * sq;
        }

        return sum / static_cast<float_t>(y.size());
    }

    // df: gradient of f w.r.t. y, i.e. 2 * w * (y - t) / N per element.
    static vec_t df(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        vec_t grad(t.size());
        const float_t scale = float_t(2) / static_cast<float_t>(t.size());
        const float_t w_coor = 5.0f;
        const float_t w_iou = 0.5f;

        for (size_t idx = 0; idx < y.size(); ++idx)
        {
            const float_t diff = y[idx] - t[idx];
            grad[idx] = scale * ((idx % 5 == 4) ? w_iou * diff : w_coor * diff);
        }

        return grad;
    }
};
/////////////////////////////////////////////////////////////////////////////////////////////////////////

// mean-squared-error loss function for regression
class mse
{
  public:
    // f: average of squared differences between prediction y and target t.
    static float_t f(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        float_t sum{0.0};

        for (size_t idx = 0; idx < y.size(); ++idx)
        {
            const float_t diff = y[idx] - t[idx];
            sum += diff * diff;
        }

        return sum / static_cast<float_t>(y.size());
    }

    // df: gradient of f w.r.t. y, i.e. 2 * (y - t) / N per element.
    static vec_t df(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        const float_t scale = float_t(2) / static_cast<float_t>(t.size());
        vec_t grad(t.size());

        for (size_t idx = 0; idx < y.size(); ++idx)
            grad[idx] = scale * (y[idx] - t[idx]);

        return grad;
    }
};

// absolute loss function for regression
class absolute
{
  public:
    // f: mean absolute difference between prediction y and target t.
    static float_t f(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        float_t sum{0};

        for (size_t idx = 0; idx < y.size(); ++idx)
            sum += std::abs(y[idx] - t[idx]);

        return sum / static_cast<float_t>(y.size());
    }

    // df: subgradient of f, sign(y - t) / N per element (0 where y == t).
    static vec_t df(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        const float_t scale = float_t(1) / static_cast<float_t>(t.size());
        vec_t grad(t.size());

        for (size_t idx = 0; idx < y.size(); ++idx)
        {
            const float_t diff = y[idx] - t[idx];
            if (diff > float_t{0.f})
                grad[idx] = scale;
            else if (diff < float_t{0.f})
                grad[idx] = -scale;
            else
                grad[idx] = {0};
        }

        return grad;
    }
};

// absolute loss with epsilon range for regression
// epsilon range [-eps, eps] with eps = 1./fraction
template <int fraction>
class absolute_eps
{
  public:
    // f: mean absolute difference, ignoring residuals inside [-eps, eps].
    static float_t f(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        const float_t eps = float_t(1) / fraction;
        float_t sum{0};

        for (size_t idx = 0; idx < y.size(); ++idx)
        {
            const float_t dist = std::abs(y[idx] - t[idx]);
            // residuals within the dead zone contribute nothing
            if (dist > eps)
                sum += dist;
        }
        return sum / static_cast<float_t>(y.size());
    }

    // df: sign(y - t) / N where |y - t| exceeds eps, 0 inside the dead zone.
    static vec_t df(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        const float_t scale = float_t(1) / static_cast<float_t>(t.size());
        const float_t eps = float_t(1) / fraction;
        vec_t grad(t.size());

        for (size_t idx = 0; idx < y.size(); ++idx)
        {
            const float_t diff = y[idx] - t[idx];
            if (diff > eps)
                grad[idx] = scale;
            else if (diff < -eps)
                grad[idx] = -scale;
            else
                grad[idx] = 0.f;
        }
        return grad;
    }
};

// cross-entropy loss function for (multiple independent) binary classifications
class cross_entropy
{
  public:
    // f: sum over elements of -t*log(y) - (1-t)*log(1-y).
    // NOTE(review): no clamping — y must stay strictly inside (0, 1) or the
    // log diverges; callers appear to rely on the preceding activation for that.
    static float_t f(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        float_t sum{0};

        for (size_t idx = 0; idx < y.size(); ++idx)
        {
            sum += -t[idx] * std::log(y[idx]) -
                   (float_t(1) - t[idx]) * std::log(float_t(1) - y[idx]);
        }

        return sum;
    }

    // df: per-element gradient (y - t) / (y * (1 - y)).
    static vec_t df(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        vec_t grad(t.size());

        for (size_t idx = 0; idx < y.size(); ++idx)
        {
            const float_t denom = y[idx] * (float_t(1) - y[idx]);
            grad[idx] = (y[idx] - t[idx]) / denom;
        }

        return grad;
    }
};

// cross-entropy loss function for multi-class classification
class cross_entropy_multiclass
{
  public:
    // f: sum over elements of -t[i] * log(y[i]).
    static float_t f(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        float_t sum{0.0};

        for (size_t idx = 0; idx < y.size(); ++idx)
            sum -= t[idx] * std::log(y[idx]);

        return sum;
    }

    // df: per-element gradient -t[i] / y[i].
    static vec_t df(const vec_t &y, const vec_t &t)
    {
        assert(y.size() == t.size());
        vec_t grad(t.size());

        for (size_t idx = 0; idx < y.size(); ++idx)
            grad[idx] = -t[idx] / y[idx];

        return grad;
    }
};

// Gradient of loss E for a single sample: forwards to E::df, which returns
// dE/dy for prediction y against target t.
template <typename E>
vec_t gradient(const vec_t &y, const vec_t &t)
{
    assert(y.size() == t.size());
    return E::df(y, t);
}

template <typename E>
std::vector<vec_t> gradient(const std::vector<vec_t> &y,
                            const std::vector<vec_t> &t)
{
    std::vector<vec_t> grads(y.size());

    assert(y.size() == t.size());

    for (size_t i = 0; i < y.size(); i++)
        grads[i] = gradient<E>(y[i], t[i]);

    return grads;
}

inline void apply_cost_if_defined(std::vector<vec_t> &sample_gradient,
                                  const std::vector<vec_t> &sample_cost)
{
    if (sample_gradient.size() == sample_cost.size())
    {
        // @todo consider adding parallelism
        const size_t channel_count = sample_gradient.size();
        for (size_t channel = 0; channel < channel_count; ++channel)
        {
            if (sample_gradient[channel].size() == sample_cost[channel].size())
            {
                const size_t element_count = sample_gradient[channel].size();

                // @todo optimize? (use AVX or so)
                for (size_t element = 0; element < element_count; ++element)
                {
                    sample_gradient[channel][element] *= sample_cost[channel][element];
                }
            }
        }
    }
}

// gradient for a minibatch
//
// y      : per-sample network outputs (sample x channel x element)
// t      : per-sample targets, same outer shape as y
// t_cost : optional per-sample element-wise cost weights; when a sample has a
//          matching cost tensor its gradient is scaled element-wise by it
template <typename E>
std::vector<tensor_t> gradient(const std::vector<tensor_t> &y,
                               const std::vector<tensor_t> &t,
                               const std::vector<tensor_t> &t_cost)
{
    const size_t sample_count = y.size();
    // fix: y[0] on an empty minibatch is undefined behavior — guard it
    const size_t channel_count = y.empty() ? 0 : y[0].size();

    std::vector<tensor_t> gradients(sample_count);

    CNN_UNREFERENCED_PARAMETER(channel_count);
    assert(y.size() == t.size());
    assert(t_cost.empty() || t_cost.size() == t.size());

    // @todo add parallelism
    for (size_t sample = 0; sample < sample_count; ++sample)
    {
        assert(y[sample].size() == channel_count);
        assert(t[sample].size() == channel_count);
        assert(t_cost.empty() || t_cost[sample].empty() ||
               t_cost[sample].size() == channel_count);

        gradients[sample] = gradient<E>(y[sample], t[sample]);

        // scale by the user-supplied cost when one exists for this sample
        if (sample < t_cost.size())
        {
            apply_cost_if_defined(gradients[sample], t_cost[sample]);
        }
    }

    return gradients;
}

} // namespace tiny_dnn
